from bluepyopt.ephys.responses import TimeVoltageResponse
from bluepyparallel import init_parallel_factory
import yaml
from functools import partial

from pathlib import Path
import matplotlib.pyplot as plt
import pickle

# from emodel_generalisation.utils import plot_traces
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from emodel_generalisation.mcmc import load_chains
from emodel_generalisation.mcmc import save_selected_emodels
from itertools import cycle
import json
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.model.evaluation import feature_evaluation
from emodel_generalisation.utils import get_combo_hash
from emodel_generalisation.utils import get_feature_df
from datareuse import Reuse
import extra_features


def plot_traces(trace_df, trace_path="traces", pdf_filename="traces.pdf"):
    """Plot the ``Step_ReboundBurst*`` voltage traces of each row of a dataframe into a pdf.

    For every row, the pickled responses are loaded and each response whose protocol name
    starts with ``Step_ReboundBurst`` is drawn in black on a figure labelled with the protocol
    name (one figure per protocol, shared by all rows). Each of these figures becomes one page
    of the pdf, and is closed afterwards. Figures that were already open before the call are
    neither saved nor closed.

    Args:
        trace_df (DataFrame): contains list of combos with traces to plot. If it has a
            ``trace_data`` column, that column gives the path of the .pkl file of each row,
            otherwise the file name is derived from ``get_combo_hash`` of the row.
        trace_path (str): path to folder with traces in .pkl, only used when ``trace_df`` has
            no ``trace_data`` column
        pdf_filename (str): name of pdf to save
    """
    trace_df = trace_df.copy()  # work on a copy, the caller's dataframe is not modified
    # NOTE(review): this column is not used for styling here (all traces are black); it is
    # kept so that the rows handed to get_combo_hash are the same as before. Confirm whether
    # it can be dropped.
    if "trace_highlight" not in trace_df.columns:
        trace_df["trace_highlight"] = True

    # The folder must not be overwritten inside the loop: reusing the `trace_path` name for the
    # per-row file made every row after the first look for its file below the previous file.
    trace_folder = Path(trace_path)
    has_trace_data = "trace_data" in trace_df.columns

    figures = {}  # protocol name -> figure, in order of first appearance
    for index in trace_df.index:
        if has_trace_data:
            trace_file = trace_df.loc[index, "trace_data"]
        else:
            combo_hash = get_combo_hash(trace_df.loc[index])
            trace_file = trace_folder / ("trace_id_" + str(combo_hash) + ".pkl")

        # NOTE: pickle can execute arbitrary code on load, only use trace files you trust.
        with open(trace_file, "rb") as f:
            trace = pickle.load(f)
        if isinstance(trace, list):
            trace = trace[1]  # in newer versions the responses are the second element

        for protocol, response in trace.items():
            if not protocol.startswith("Step_ReboundBurst"):
                continue
            if not isinstance(response, TimeVoltageResponse):
                continue
            # plt.figure with a label re-activates the figure if it is already open
            figures[protocol] = plt.figure(protocol, figsize=(30, 7))
            # plt.gca().set_xlim(4500, 14000)
            plt.plot(response["time"], response["voltage"], c="k", lw=1)

    with PdfPages(pdf_filename) as pdf:
        for fig in figures.values():
            plt.figure(fig.number)  # make it current so plt.legend targets its axes
            plt.legend(loc="best")
            fig.suptitle(fig.get_label())
            pdf.savefig(fig)
            plt.close(fig)


if __name__ == "__main__":
    # Select one model per firing-pattern mask from the MCMC chains, re-evaluate it with trace
    # recording, plot its rebound-burst traces and add its burst number versus holding voltage
    # curve to a common figure saved as burst_scan.pdf.
    parallel_factory = init_parallel_factory("multiprocessing")
    # parallel_factory = init_parallel_factory("dask_dataframe")
    emodel = "simplest"

    # common prefixes of the feature columns used to build the masks below
    burst = "Step_ReboundBurst_burst.soma.v"
    burst_high = "Step_ReboundBurst_burst_high.soma.v"

    mcmc_df = load_chains("../../mcmc_run/run_df.csv", base_path="../../mcmc_run")
    mcmc_df = mcmc_df[mcmc_df.cost < 2.0]
    features = mcmc_df["features"]
    # keep models that do not burst more in the `burst_high` protocol than in `burst`
    fewer_bursts_high = (
        features[f"{burst_high}.all_burst_number"] <= features[f"{burst}.all_burst_number"]
    )
    mcmc_df = mcmc_df[fewer_bursts_high].reset_index(drop=True)
    features = mcmc_df["features"]  # refresh, rows and index changed above

    mask_runaway = (
        (features[f"{burst}.burst_runaway"] < 0.05)
        & (features[f"{burst}.time_to_last_spike"] > 20000)
        & (features[f"{burst_high}.tonic_after_burst"] <= 2.0)
    )
    mask_spp = (
        ~mask_runaway
        & (features[f"{burst}.all_burst_number"] > 5)
        & (features[f"{burst}.tonic_after_burst"] <= 2.0)
        & (features[f"{burst_high}.tonic_after_burst"] <= 2.0)
    )
    mask_ecel = (
        (features[f"{burst}.all_burst_number"] < 5)
        & (features[f"{burst}.tonic_after_burst"] == 0.0)
        & (features[f"{burst_high}.tonic_after_burst"] == 0.0)
    )
    # only used by the commented-out entry of the loop below
    mask_tonic = (features[f"{burst}.tonic_after_burst"] < 10) & (
        features[f"{burst}.tonic_after_burst"] > 5
    )

    # the exemplar data does not depend on the selection, read it once and close the file
    with open("../../mcmc_run/exemplar_data.yaml") as exemplar_file:
        exemplar_data = yaml.safe_load(exemplar_file)

    Path("traces").mkdir(exist_ok=True)
    fig, ax = plt.subplots(figsize=(5, 3))
    for selection, name in [
        (mask_ecel, "ecel"),
        (mask_spp, "spp"),
        (mask_runaway, "runaway"),
        # (mask_tonic, "tonic"),
    ]:
        final_path = f"final_{name}.json"

        # one reproducible random pick among the models matching the mask
        selected = mcmc_df[selection].sample(1, random_state=42).reset_index()
        selected.to_csv(f"params_{name}.csv")
        save_selected_emodels(selected, selected.index, emodel=emodel, final_path=final_path)

        access_point = AccessPoint(
            emodel_dir=".",
            recipes_path="config/recipes.json",
            final_path=final_path,
            with_seeds=True,
        )
        access_point.morph_path = exemplar_data["paths"]["all"]
        access_point.settings["morph_modifiers"] = [
            partial(synth_soma, params=exemplar_data["soma"], scale=1.0),
            partial(synth_axon, params=exemplar_data["ais"]["popt"], scale=1.0),
        ]

        with open(final_path) as final_file:
            final = json.load(final_file)

        # One row per emodel found in the final file. The loop variable must not be called
        # `emodel`: it used to overwrite the prefix passed to save_selected_emodels for all
        # selections after the first one.
        combos = pd.DataFrame()
        for row, emodel_name in enumerate(final):
            combos.loc[row, "name"] = name
            combos.loc[row, "emodel"] = emodel_name

        # Reuse caches the evaluation in the csv file
        # NOTE(review): presumably the computation is skipped if the csv exists; confirm
        with Reuse(f"eval_{name}.csv") as reuse:
            evaluated = reuse(
                feature_evaluation,
                combos,
                access_point,
                parallel_factory=parallel_factory,
                trace_data_path="traces",
                # record_ions_and_currents=True,
            )
        plot_traces(evaluated, pdf_filename=f"traces_{name}.pdf")
        # `fig` is still drawn on and saved below through its own handle
        plt.close("all")

        feature_df = get_feature_df(evaluated)
        print(feature_df)

        # Feature columns are named `<protocol>.<...>.<feature>`. Keep the `voltage_base`
        # columns of protocols whose name has four `_`-separated parts and pair each of them
        # with the `all_burst_number` feature of the same recording (first row only).
        # NOTE(review): presumably these protocols are the holding voltage scan; confirm
        scan_rows = []
        for col in feature_df.columns:
            if len(col.split(".")[0].split("_")) != 4 or not col.endswith(".voltage_base"):
                continue
            burst_col = ".".join(col.split(".")[:-1] + ["all_burst_number"])
            scan_rows.append(
                {
                    "burst_number": feature_df[burst_col].tolist()[0],
                    "holding_voltage": feature_df[col].tolist()[0],
                }
            )
        scan_df = pd.DataFrame(scan_rows)
        print(scan_df)

        scan_df.plot(x="holding_voltage", y="burst_number", ax=ax, marker="+", label=name)
        ax.set_xlabel("holding_voltage")
        ax.set_ylabel("burst_number")
        ax.legend(loc="best")
        ax.axis([-80, -50, 0, 37])
        fig.tight_layout()
        # saved after each selection, so a partial figure exists if a later one fails
        fig.savefig("burst_scan.pdf")
